{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the NVIDIA GPUs visible to this runtime.\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clone the TriplaneGaussian repo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the TriplaneGaussian sources into ./TriplaneGaussian.\n",
    "!git clone https://github.com/VAST-AI-Research/TriplaneGaussian.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %cd (rather than !cd) changes the kernel's working directory, so it persists for later cells.\n",
    "%cd TriplaneGaussian"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install dependencies. It may take a while."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -r requirements.txt\n",
    "\n",
    "# install pointnet2_ops_lib (built from the copy bundled in the repo), then return to the repo root\n",
    "%cd tgs/models/snowflake/pointnet2_ops_lib\n",
    "!python setup.py install\n",
    "%cd ../../../..\n",
    "\n",
    "# install pytorch_scatter\n",
    "# torch.__version__ (e.g. 2.1.0+cu121) selects the PyG wheel page built for the installed\n",
    "# torch/CUDA pair, so the prebuilt wheel always matches the runtime instead of a hard-coded version.\n",
    "import torch\n",
    "!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
    "\n",
    "# install diff-gaussian-rasterization\n",
    "# -y answers apt-get's confirmation prompt, which a notebook cell cannot respond to.\n",
    "!apt-get install -y libglm-dev\n",
    "!git clone https://github.com/graphdeco-inria/diff-gaussian-rasterization.git\n",
    "%cd diff-gaussian-rasterization\n",
    "!python setup.py install\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install PyTorch3D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "\n",
    "# The wheel index URL below is keyed by a Python/CUDA/PyTorch tag; build that tag from the\n",
    "# running environment, e.g. py310_cu121_pyt210 for Python 3.10, CUDA 12.1, PyTorch 2.1.0.\n",
    "pyt_version_str = torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
    "cuda_version_str = torch.version.cuda.replace(\".\", \"\")\n",
    "version_str = f\"py3{sys.version_info.minor}_cu{cuda_version_str}_pyt{pyt_version_str}\"\n",
    "\n",
    "!pip install fvcore iopath\n",
    "# --no-index limits pip to the page given by -f, so only a prebuilt wheel can be installed.\n",
    "!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download SAM checkpoint\n",
    "# Safe to re-run: mkdir -p accepts an existing directory, and wget -c resumes a partial file\n",
    "# or does nothing when the checkpoint is already complete (no duplicate .1 download).\n",
    "# -P saves into checkpoints/ without changing the notebook's working directory.\n",
    "!mkdir -p checkpoints\n",
    "!wget -c -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference on one bundled example image.\n",
    "# The override is quoted so the shell never treats [...] as a glob pattern.\n",
    "!python infer.py --config config.yaml \"data.image_list=[example_images/a_pikachu_with_smily_face.webp,]\" --image_preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the rendered video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "from base64 import b64encode\n",
    "def display_video(video_path):\n",
    "  \"\"\"Return an inline HTML5 player for the MP4 file at `video_path`.\n",
    "\n",
    "  The whole file is read and embedded as a base64 data URL, so the returned\n",
    "  HTML object carries the video bytes itself.\n",
    "  \"\"\"\n",
    "  # The context manager closes the file handle as soon as the bytes are read.\n",
    "  with open(video_path, 'rb') as video_file:\n",
    "    mp4 = video_file.read()\n",
    "  data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
    "  return HTML(\"\"\"\n",
    "  <video width=400 controls>\n",
    "    <source src=\"%s\" type=\"video/mp4\">\n",
    "  </video>\n",
    "  \"\"\" % data_url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "\n",
    "save_dir = './outputs/video'\n",
    "\n",
    "# glob returns matches in arbitrary order; sorting makes the chosen video deterministic.\n",
    "video_paths = sorted(glob.glob(os.path.join(save_dir, \"*.mp4\")))\n",
    "if not video_paths:\n",
    "  # Fail with an actionable message instead of a bare IndexError on [0].\n",
    "  raise FileNotFoundError(f\"No .mp4 files found in {save_dir}; run the inference cell first.\")\n",
    "video_path = video_paths[0]\n",
    "display_video(video_path)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
